// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/tint/lang/wgsl/ast/transform/spirv_atomic.h"

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "src/tint/lang/core/type/reference.h"
#include "src/tint/lang/wgsl/program/clone_context.h"
#include "src/tint/lang/wgsl/program/program_builder.h"
#include "src/tint/lang/wgsl/resolver/resolve.h"
#include "src/tint/lang/wgsl/sem/block_statement.h"
#include "src/tint/lang/wgsl/sem/function.h"
#include "src/tint/lang/wgsl/sem/index_accessor_expression.h"
#include "src/tint/lang/wgsl/sem/member_accessor_expression.h"
#include "src/tint/lang/wgsl/sem/statement.h"
#include "src/tint/utils/containers/map.h"
#include "src/tint/utils/containers/unique_vector.h"
#include "src/tint/utils/rtti/switch.h"

// Type-info instantiations for the transform and for its stub attribute.
// NOTE(review): presumably these back the As<>() / Switch() style casts used on these classes -
// confirm against the definition of TINT_INSTANTIATE_TYPEINFO.
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::SpirvAtomic);
TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::SpirvAtomic::Stub);

namespace tint::ast::transform {

// Brings the names of tint::number_suffixes into scope for this namespace.
// NOTE(review): presumably these are the numeric literal suffixes (e.g. `_u`, `_i`); no such
// literal is apparent in the code below - confirm before considering removal.
using namespace tint::number_suffixes;  // NOLINT

/// PIMPL state for the transform.
///
/// Holds the source program, the builder for the output program and the bookkeeping collected
/// while stub calls are rewritten: which expressions, variables and structure members have to
/// become atomic. A State is constructed for one source program and driven through Run().
struct SpirvAtomic::State {
  private:
    /// A struct that has been forked because a subset of members were made atomic.
    struct ForkedStruct {
        /// Symbol used for the forked structure in the output program. Assigned by Fork() the
        /// first time the structure is forked.
        Symbol name;
        /// Indices (within the source structure) of the members that must use atomic types.
        std::unordered_set<size_t> atomic_members;
    };

    /// The source program
    const Program* const src;
    /// The target program builder
    ProgramBuilder b;
    /// The clone context
    program::CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
    /// Source structure types that need an atomic fork, keyed by their semantic type.
    std::unordered_map<const core::type::Struct*, ForkedStruct> forked_structs;
    /// Variables whose declared type has been replaced by its atomic equivalent.
    std::unordered_set<const sem::Variable*> atomic_variables;
    /// Work list of expressions that were passed as the atomic argument of a stub call, plus the
    /// expressions those originate from. ProcessAtomicExpressions() appends to this list while
    /// it iterates over it.
    UniqueVector<const sem::ValueExpression*, 8> atomic_expressions;

  public:
    /// Constructor
    /// @param program the source program
    explicit State(const Program* program) : src(program) {}

    /// Runs the transform
    /// @returns the new program or SkipTransform if the transform is not required
    ApplyResult Run() {
        // Set as soon as at least one stub function has been found.
        bool made_changes = false;

        // Look for stub functions generated by the SPIR-V reader, which are used as placeholders
        // for atomic builtin calls.
        for (auto* fn : ctx.src->AST().Functions()) {
            if (auto* stub = GetAttribute<Stub>(fn->attributes)) {
                auto* sem = ctx.src->Sem().Get(fn);

                for (auto* call : sem->CallSites()) {
                    // The first argument is always the atomic.
                    // The stub passes this by value, whereas the builtin wants a pointer.
                    // Take the address of the atomic argument.
                    auto& args = call->Declaration()->args;
                    auto out_args = ctx.Clone(args);
                    out_args[0] = b.AddressOf(out_args[0]);

                    // Replace all callsites of this stub to a call to the real builtin
                    if (stub->builtin == core::Function::kAtomicCompareExchangeWeak) {
                        // atomicCompareExchangeWeak returns a struct, so insert a call to it above
                        // the current statement, and replace the current call with the struct's
                        // `old_value` member.
                        auto* block = call->Stmt()->Block()->Declaration();
                        auto old_value = b.Symbols().New("old_value");
                        auto old_value_decl = b.Decl(b.Let(
                            old_value,
                            b.MemberAccessor(b.Call(core::str(stub->builtin), std::move(out_args)),
                                             "old_value")));
                        ctx.InsertBefore(block->statements, call->Stmt()->Declaration(),
                                         old_value_decl);
                        ctx.Replace(call->Declaration(), b.Expr(old_value));
                    } else {
                        // Every other builtin can take the place of the stub call directly.
                        ctx.Replace(call->Declaration(),
                                    b.Call(core::str(stub->builtin), std::move(out_args)));
                    }

                    // Keep track of this expression. We'll need to modify the root identifier /
                    // structure to be atomic.
                    atomic_expressions.Add(ctx.src->Sem().GetVal(args[0]));
                }

                // Remove the stub from the output program
                ctx.Remove(ctx.src->AST().GlobalDeclarations(), fn);
                made_changes = true;
            }
        }

        // No stub functions in the program: nothing for this transform to do.
        if (!made_changes) {
            return SkipTransform;
        }

        // Transform all variables and structure members that were used in atomic operations as
        // atomic types. This propagates up originating expression chains.
        ProcessAtomicExpressions();

        // If we need to change structure members, then fork them.
        if (!forked_structs.empty()) {
            ctx.ReplaceAll([&](const Struct* str) {
                // Is `str` a structure we need to fork?
                auto* str_ty = ctx.src->Sem().Get(str);
                if (auto it = forked_structs.find(str_ty); it != forked_structs.end()) {
                    const auto& forked = it->second;

                    // Re-create the structure swapping in the atomic-flavoured members
                    tint::Vector<const StructMember*, 8> members;
                    members.Reserve(str->members.Length());
                    for (size_t i = 0; i < str->members.Length(); i++) {
                        auto* member = str->members[i];
                        if (forked.atomic_members.count(i)) {
                            // Atomic member: same name and attributes, atomic type.
                            auto type = AtomicTypeFor(ctx.src->Sem().Get(member)->Type());
                            auto name = member->name->symbol.Name();
                            members.Push(b.Member(name, type, ctx.Clone(member->attributes)));
                        } else {
                            // Untouched member: plain clone.
                            members.Push(ctx.Clone(member));
                        }
                    }
                    // Declare the fork under the symbol reserved by Fork().
                    b.Structure(forked.name, std::move(members));
                }
                // NOTE(review): a null result presumably lets the clone context clone the
                // original structure as usual, so the fork is emitted in addition to the
                // original rather than instead of it - confirm against CloneContext::ReplaceAll.
                return nullptr;
            });
        }

        // Replace assignments and decls from atomic variables with atomicLoads, and assignments to
        // atomic variables with atomicStores.
        ReplaceLoadsAndStores();

        // Clone the program through the context (with everything registered above) and resolve
        // what the builder now holds.
        ctx.Clone();
        return resolver::Resolve(b);
    }

  private:
    /// Returns the ForkedStruct record for `str`, creating it on first use.
    /// The first time a structure is forked, a new symbol is requested from the builder using the
    /// source structure's name with "_atomic" appended.
    /// @param str the semantic structure type to fork
    /// @returns a reference to the record stored in `forked_structs`
    ForkedStruct& Fork(const core::type::Struct* str) {
        // operator[] default-constructs the record if `str` has not been forked before.
        auto& forked = forked_structs[str];
        if (!forked.name.IsValid()) {
            forked.name = b.Symbols().New(str->Name().Name() + "_atomic");
        }
        return forked;
    }

    /// Walks the `atomic_expressions` work list and marks what each expression originates from
    /// as atomic: variables get their declared type replaced, structure members are recorded in
    /// the fork of their structure. The handlers append further expressions to the list, which
    /// is why the loop is index based and re-reads Length() on every iteration.
    void ProcessAtomicExpressions() {
        for (size_t i = 0; i < atomic_expressions.Length(); i++) {
            Switch(
                atomic_expressions[i]->UnwrapLoad(),  //
                [&](const sem::VariableUser* user) {
                    // Replace the declared type of the variable, but only if it has an explicit
                    // type and only the first time the variable is encountered.
                    auto* v = user->Variable()->Declaration();
                    if (v->type && atomic_variables.emplace(user->Variable()).second) {
                        ctx.Replace(v->type.expr, b.Expr(AtomicTypeFor(user->Variable()->Type())));
                    }
                    // Follow the variable's initializer, if it has one.
                    if (auto* ctor = user->Variable()->Initializer()) {
                        atomic_expressions.Add(ctor);
                    }
                },
                [&](const sem::StructMemberAccess* access) {
                    // Fork the struct (the first time) and mark member(s) that need to be made
                    // atomic.
                    auto* member = access->Member();
                    Fork(member->Struct()).atomic_members.emplace(member->Index());
                    // Continue with the object the member is accessed on.
                    atomic_expressions.Add(access->Object());
                },
                [&](const sem::IndexAccessorExpression* index) {
                    // Continue with the object being indexed.
                    atomic_expressions.Add(index->Object());
                },
                [&](const sem::ValueExpression* e) {
                    // Any other value expression: if it is a unary operator expression, continue
                    // with its operand (the operator itself is not inspected).
                    if (auto* unary = e->Declaration()->As<UnaryOpExpression>()) {
                        atomic_expressions.Add(ctx.src->Sem().GetVal(unary->expr));
                    }
                });
        }
    }

    /// Builds the AST type that is the atomic counterpart of the semantic type `ty`.
    /// i32 and u32 are wrapped in an atomic type, structures map to the name of their fork
    /// (forking the structure via Fork() if that has not happened yet), and arrays, pointers and
    /// references are handled by recursing into their element / store type.
    /// @param ty the semantic type to convert
    /// @returns the atomic AST type; for an unhandled type TINT_ICE() is invoked and a
    /// default-constructed Type is returned
    Type AtomicTypeFor(const core::type::Type* ty) {
        return Switch(
            ty,  //
            [&](const core::type::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
            [&](const core::type::U32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
            [&](const core::type::Struct* str) { return b.ty(Fork(str).name); },
            [&](const core::type::Array* arr) {
                // Runtime-sized arrays carry no count.
                if (arr->Count()->Is<core::type::RuntimeArrayCount>()) {
                    return b.ty.array(AtomicTypeFor(arr->ElemType()));
                }
                auto count = arr->ConstantCount();
                if (!count) {
                    ctx.dst->Diagnostics().add_error(
                        diag::System::Transform,
                        "the SpirvAtomic transform does not currently support array counts that "
                        "use override values");
                    // Fall back to a count of 1 so a type can still be built after the error has
                    // been reported.
                    count = 1;
                }
                return b.ty.array(AtomicTypeFor(arr->ElemType()), u32(count.value()));
            },
            [&](const core::type::Pointer* ptr) {
                // Same address space and access, atomic store type.
                return b.ty.ptr(ptr->AddressSpace(), AtomicTypeFor(ptr->StoreType()),
                                ptr->Access());
            },
            [&](const core::type::Reference* ref) { return AtomicTypeFor(ref->StoreType()); },
            [&](Default) {
                TINT_ICE() << "unhandled type: " << ty->FriendlyName();
                return Type{};
            });
    }

    /// Registers replacements that turn plain reads and writes of the collected atomic variables
    /// into atomicLoad / atomicStore calls. Only users whose statement is an assignment or a
    /// variable declaration are considered.
    void ReplaceLoadsAndStores() {
        // Returns true if 'e' is a reference to an atomic variable or struct member
        auto is_ref_to_atomic_var = [&](const sem::ValueExpression* e) {
            if (tint::Is<core::type::Reference>(e->Type()) && e->RootIdentifier() &&
                (atomic_variables.count(e->RootIdentifier()) != 0)) {
                // If it's a struct member, make sure it's one we marked as atomic
                if (auto* ma = e->As<sem::StructMemberAccess>()) {
                    auto it = forked_structs.find(ma->Member()->Struct());
                    if (it != forked_structs.end()) {
                        auto& forked = it->second;
                        return forked.atomic_members.count(ma->Member()->Index()) != 0;
                    }
                }
                // Not a member access, or a member of a structure that was not forked.
                return true;
            }
            return false;
        };

        // Look for loads and stores via assignments and decls of atomic variables we've collected
        // so far, and replace them with atomicLoad and atomicStore.
        for (auto* atomic_var : atomic_variables) {
            for (auto* vu : atomic_var->Users()) {
                Switch(
                    vu->Stmt()->Declaration(),
                    [&](const AssignmentStatement* assign) {
                        // Store: the whole assignment becomes the call statement
                        // atomicStore(&lhs, rhs).
                        auto* sem_lhs = ctx.src->Sem().GetVal(assign->lhs);
                        if (is_ref_to_atomic_var(sem_lhs)) {
                            ctx.Replace(assign, [=] {
                                // NOTE(review): CloneWithoutTransform presumably clones the node
                                // without applying the replacements registered on the context -
                                // confirm.
                                auto* lhs = ctx.CloneWithoutTransform(assign->lhs);
                                auto* rhs = ctx.CloneWithoutTransform(assign->rhs);
                                auto* call = b.Call(core::str(core::Function::kAtomicStore),
                                                    b.AddressOf(lhs), rhs);
                                return b.CallStmt(call);
                            });
                            return;
                        }

                        // Load: only the right-hand side becomes atomicLoad(&rhs).
                        auto sem_rhs = ctx.src->Sem().GetVal(assign->rhs);
                        if (is_ref_to_atomic_var(sem_rhs->UnwrapLoad())) {
                            ctx.Replace(assign->rhs, [=] {
                                auto* rhs = ctx.CloneWithoutTransform(assign->rhs);
                                return b.Call(core::str(core::Function::kAtomicLoad),
                                              b.AddressOf(rhs));
                            });
                            return;
                        }
                    },
                    [&](const VariableDeclStatement* decl) {
                        // Declaration initialized from an atomic: the initializer becomes
                        // atomicLoad(&initializer).
                        auto* var = decl->variable;
                        if (auto* sem_init = ctx.src->Sem().GetVal(var->initializer)) {
                            if (is_ref_to_atomic_var(sem_init->UnwrapLoad())) {
                                ctx.Replace(var->initializer, [=] {
                                    auto* rhs = ctx.CloneWithoutTransform(var->initializer);
                                    return b.Call(core::str(core::Function::kAtomicLoad),
                                                  b.AddressOf(rhs));
                                });
                                return;
                            }
                        }
                    });
            }
        }
    }
};

/// Constructor. Defaulted, but defined out of line in this translation unit.
SpirvAtomic::SpirvAtomic() = default;
/// Destructor. Defaulted, but defined out of line in this translation unit.
SpirvAtomic::~SpirvAtomic() = default;

/// Constructs a stub attribute that records which builtin the annotated function stands in for.
/// @param pid the generation identifier, forwarded to the base class
/// @param nid the node identifier, forwarded to the base class
/// @param b the builtin stored in `builtin`
/// The base class additionally receives tint::Empty as its third argument.
SpirvAtomic::Stub::Stub(GenerationID pid, NodeID nid, core::Function b)
    : Base(pid, nid, tint::Empty), builtin(b) {}
/// Destructor. Defaulted, but defined out of line in this translation unit.
SpirvAtomic::Stub::~Stub() = default;
/// @returns the textual form of this attribute: "@internal(spirv-atomic <builtin>)", where
/// <builtin> is the string form of the builtin this stub stands in for.
std::string SpirvAtomic::Stub::InternalName() const {
    std::string internal_name = "@internal(spirv-atomic ";
    internal_name += std::string(core::str(builtin));
    internal_name += ")";
    return internal_name;
}

/// Creates a copy of this stub attribute in the destination of the clone context.
/// @param ctx the clone context
/// @returns the new stub, carrying the same builtin, with the destination's ID and a newly
/// allocated node ID
const SpirvAtomic::Stub* SpirvAtomic::Stub::Clone(ast::CloneContext& ctx) const {
    const auto generation_id = ctx.dst->ID();
    const auto node_id = ctx.dst->AllocateNodeID();
    return ctx.dst->ASTNodes().Create<SpirvAtomic::Stub>(generation_id, node_id, builtin);
}

/// Applies the transform to `src` by delegating to a freshly constructed State.
/// The two DataMap arguments are not used.
/// @param src the source program
/// @returns whatever State::Run() produces for `src`
Transform::ApplyResult SpirvAtomic::Apply(const Program* src, const DataMap&, DataMap&) const {
    State state{src};
    return state.Run();
}

}  // namespace tint::ast::transform
